{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9bf02855-2a2a-433f-bc8a-883abf99ee34",
   "metadata": {},
   "source": [
    "# Perplexity"
   ]
  },
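  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "61c9de47-7f26-4c77-8c8f-0592716c337b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def calculate_perplexity(probabilities):\n",
    "    # probabilities: per-word probabilities the model assigns to the observed words\n",
    "    log_probs = np.log2(probabilities)\n",
    "    # average log2-probability per word\n",
    "    avg_log_prob = np.mean(log_probs)\n",
    "    # perplexity is 2 raised to the negative average log2-probability\n",
    "    perplexity = 2 ** (-avg_log_prob)\n",
    "    return perplexity"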
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2d17c850-0867-4196-9477-895b879e975a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Perplexity sentence 1: 1.0567214564189926\n"
     ]
    }
   ],
   "source": [
    "true_sentence = \"The quick brown fox jumps over the lazy dog\"\n",
    "sentence_1 = \"The fast black cat jumps over the lazy dog\"\n",
    "\n",
    "s1_word_proba = [0.99, 0.85, 0.89, 0.94, 0.99, 0.99, 0.99, 0.99, 0.90]\n",
    "perplexity = calculate_perplexity(s1_word_proba)\n",
    "print(\"Perplexity sentence 1:\", perplexity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b50c6432-b7e2-4400-8ef1-4ed393019444",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Perplexity sentence 2: 2.2188609051008896\n"
     ]
    }
   ],
   "source": [
    "sentence_2 = \"The bold orange car drove by the lazy dog\"\n",
    "\n",
    "s2_word_proba = [0.99, 0.65, 0.13, 0.05, 0.21, 0.99, 0.99, 0.99, 0.90]\n",
    "perplexity = calculate_perplexity(s2_word_proba)\n",
    "print(\"Perplexity sentence 2:\", perplexity)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "747e4746-db4f-4dce-9232-fd681ccdc463",
   "metadata": {},
   "source": [
    "## Relationship to Cross Entropy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "30d5ea03-5576-4ac3-9acd-12deae2004ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cross_entropy(p, q):\n",
    "    # Clip q to avoid log2(0) which is undefined\n",
    "    q = np.clip(q, 1e-10, 1.0)\n",
    "    H = -np.sum(p * np.log2(q))\n",
    "    \n",
    "    return H\n",
    "\n",
    "n = len(s1_word_proba)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fd3a9a63-2f89-45e9-a088-c97edd6744e5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7163562924630626"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cross_entropy(np.ones(n), s1_word_proba)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "77d3c672-5688-4131-a093-6ed099398f10",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0567214564189926"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "2**(cross_entropy(np.ones(n), s1_word_proba) / n )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "678e6e34-2e7c-41e8-939d-7e061e06e5c0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0567214564189926"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "calculate_perplexity(s1_word_proba)"
   ]
  },
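  {
   "cell_type": "markdown",
   "id": "b7e2c1a4-3d5f-4e68-9a0b-1c2d3e4f5a6b",
   "metadata": {},
   "source": [
    "To see why the two results above agree, here is a short derivation. It is a sketch under the assumption made in the code above, namely that the reference distribution puts probability 1 on the observed word at each of the $N$ positions (the `np.ones(n)` argument). The symbols $q_i$ and $N$ are introduced here for illustration, with $q_i$ standing for the entries of `s1_word_proba`.\n",
    "\n",
    "$$H = -\\sum_{i=1}^{N} 1 \\cdot \\log_2 q_i = -\\sum_{i=1}^{N} \\log_2 q_i$$\n",
    "\n",
    "$$2^{H/N} = 2^{-\\frac{1}{N}\\sum_{i=1}^{N} \\log_2 q_i}$$\n",
    "\n",
    "The right-hand side is what `calculate_perplexity` computes, so the perplexity is 2 raised to the per-word cross entropy (measured in bits)."
   ]
  },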
  {
   "cell_type": "markdown",
   "id": "3f3dcf75-5218-4ed0-95cc-7b4b726827e8",
   "metadata": {},
   "source": [
    "## Perplexity with TorchMetrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "481d91c8-270b-4c05-934c-1976e3e2f890",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchmetrics.text import Perplexity"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9669c09-8409-4d30-abb6-67676ea8a545",
   "metadata": {},
   "source": [
    "Torchmetrics' perplexity takes in a `predictions` and a `target` variable. \n",
    "\n",
    "For the `predictions` it assumes the shape `[batch_size, seq_len, vocab_size]`, and for the targets it assumes the shape `[batch_size, seq_len]`.\n",
    "\n",
    "if we are only looking at one sentence, we have a batch size of 1.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dc1de00f-d527-46cb-8db9-4bd6a59463c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "sentence_1 = \"The fast black cat jumps over the lazy dog\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "262a7b52-af5e-4a68-b969-1735bfec4bd5",
   "metadata": {},
   "source": [
    "Now, in this notebook, we haven't constructed a vocabulary, which is the set of all unique words in the training set. For simplicity, let's assume the vocabulary contains the following words:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c5c781d9-d1ab-4d71-a4ca-68397880b8de",
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = {\n",
    "    0: \"The\",\n",
    "    1: \"quick\",\n",
    "    2: \"brown\",\n",
    "    3: \"fox\",\n",
    "    4: \"jumps\",\n",
    "    5: \"over\",\n",
    "    6: \"the\",\n",
    "    7: \"lazy\",\n",
    "    8: \"dog\",\n",
    "    9: \"fast\",\n",
    "    10: \"black\",\n",
    "    11: \"cat\",\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "289e9648-1270-4c79-b7c5-be6aa3f4ebf0",
   "metadata": {},
   "source": [
    "Since the vocabulary has 12 words, each word output by the model would be a 12-dimensional probability vector. So, for a sentence consisting of 9 words (\"The fast black cat jumps over the lazy dog\") we have a 1x9x12 dimensional tensor.\n",
    "\n",
    "Also, previously, we considerded the word probabilities \n",
    "\n",
    "```python\n",
    "s1_word_proba = [0.99, 0.85, 0.89, 0.99, 0.99, 0.99, 0.99, 0.99]\n",
    "```\n",
    "\n",
    "In the representation below, the vocabulary index corresponding to the word at that position will have that probability value.\n",
    "\n",
    "```python\n",
    "\n",
    "vocab = {\n",
    "    0: \"The\",\n",
    "    1: \"quick\",\n",
    "    2: \"brown\",\n",
    "    3: \"fox\",\n",
    "    4: \"jumps\",\n",
    "    5: \"over\",\n",
    "    6: \"the\",\n",
    "    7: \"lazy\",\n",
    "    8: \"dog\",\n",
    "    9: \"fast\",\n",
    "    10: \"black\",\n",
    "    11: \"cat\",\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8e71aa61-d3cd-4cb2-b271-5a766f9ceaf4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "model_outputs = torch.tensor([[\n",
    "    [0.99, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01], # The, index 0\n",
    "    [0.0, 0.0, 0.0, 0.0, 0.02, 0.05, 0.02, 0.01, 0.05, 0.85, 0.00, 0.00], # fast, index 9\n",
    "    [0.01, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.89, 0.0], # black, 10\n",
    "    [0.0, 0.0, 0.0, 0.01, 0.0, 0.05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.94], # cat, 11\n",
    "    [0.0, 0.01, 0.0, 0.0, 0.99, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # jumps, 4\n",
    "    [0.0, 0.0, 0.005, 0.005, 0.0, 0.99, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # over, 5\n",
    "    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.99, 0.0, 0.0, 0.01, 0.0, 0.0], # the, 6\n",
    "    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.99, 0.01, 0.0, 0.0, 0.0], # lazy, 7\n",
    "    [0.0, 0.0, 0.0, 0.0, 0.0, 0.05, 0.04, 0.0, 0.90, 0.0, 0.0, 0.01], # dog, 8\n",
    "]])\n",
    "\n",
    "# rows should sum to 1\n",
    "print(model_outputs.sum(axis=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb2bed18-f7c6-4d5c-8967-d3d6685b1bf3",
   "metadata": {},
   "source": [
    "Note that the list of vectors above may represent the probability vectors returned by an LLM, for example, with one vector per word. The probabilities in each row should sum to one.\n",
    "\n",
    "For example, looking at the first row\n",
    "\n",
    "```\n",
    "[0.99, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01], # The, index 0\n",
    "```\n",
    "\n",
    "this means the model assigns a probability of 0.99 to the first word in the vocabulary (\"The\"). The probabilities of the other words are 0, except for the last word (\"cat\"), which has a probability of 0.01 in this case.\n",
    "\n",
    "**Note that I assigned these probabilities arbitrarily. In an application, they would be returned by an actual LLM, which we omit here for simplicity.**\n",
    "\n",
    "Then, with the target vector containing the word indices, we can gather the probabilities corresponding to the target word indices:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "df3e9250-f033-4a19-906a-267db2508f26",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[0.9900],\n",
      "         [0.8500],\n",
      "         [0.8900],\n",
      "         [0.9400],\n",
      "         [0.9900],\n",
      "         [0.9900],\n",
      "         [0.9900],\n",
      "         [0.9900],\n",
      "         [0.9000]]])\n"
     ]
    }
   ],
   "source": [
    "targets = torch.tensor([[0, 9, 10, 11, 4, 5, 6, 7, 8]])\n",
    "\n",
    "# Gather the probabilities\n",
    "probabilities = torch.gather(model_outputs, 2, targets.unsqueeze(2))\n",
    "\n",
    "print(probabilities)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c34d790d-3353-4185-a388-688534604875",
   "metadata": {},
   "source": [
    "According to the [TorchMetrics perplexity documentation](https://torchmetrics.readthedocs.io/en/stable/text/perplexity.html), the input is a probability score:\n",
    "\n",
    "> - ``preds`` (:class:`~torch.Tensor`): Probabilities assigned to each token in a sequence with shape\n",
    "    [batch_size, seq_len, vocab_size]\n",
    "\n",
    "However, the resulting perplexity is inflated when we pass the probabilities directly. When we pass log-probabilities instead, we can reproduce the result from earlier:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5f4420eb-2cac-4baf-931b-1415fa152226",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torchmetrics version: 0.11.4\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(1.0567)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torchmetrics\n",
    "from torchmetrics.text import Perplexity\n",
    "\n",
    "print(\"torchmetrics version:\", torchmetrics.__version__)\n",
    "\n",
    "perp = Perplexity()\n",
    "perp(torch.log(model_outputs), targets)"
   ]
  },
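  {
   "cell_type": "markdown",
   "id": "d3a9f0c2-6b14-4e57-8f21-7a5c9e0b1d34",
   "metadata": {},
   "source": [
    "A possible explanation for the match above, assuming the metric applies a softmax over the vocabulary dimension of `preds` before looking up the target probabilities (this is my reading of the behavior, not something stated in the quoted documentation). Here, $p_i$ denotes the entries of one row of `model_outputs`, a symbol introduced for illustration. Since each row sums to one, a softmax over log-probabilities returns the original probabilities:\n",
    "\n",
    "$$\\mathrm{softmax}(\\log p)_i = \\frac{e^{\\log p_i}}{\\sum_j e^{\\log p_j}} = \\frac{p_i}{\\sum_j p_j} = p_i$$\n",
    "\n",
    "Under the same assumption, passing the probabilities directly would instead yield\n",
    "\n",
    "$$\\mathrm{softmax}(p)_i = \\frac{e^{p_i}}{\\sum_j e^{p_j}}$$\n",
    "\n",
    "which flattens each row. For the first row, the value 0.99 would shrink to about 0.20, which lowers the target probabilities and inflates the perplexity."
   ]
  },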
  {
   "cell_type": "markdown",
   "id": "04c1ed8c-ddbe-4902-bb56-53f2d9ecc77e",
   "metadata": {
    "tags": []
   },
   "source": [
    "## PyTorch Built-Ins"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59122c21-dd97-417a-9221-fbb96523b104",
   "metadata": {},
   "source": [
    "Note that PyTorch's `torch.nn.functional.cross_entropy` expects logits. Since we have probabilities here, we use the negative log-likelihood loss, `torch.nn.functional.nll_loss`, which expects log-probabilities as inputs (usually obtained via `torch.log_softmax(logits, dim=-1)`).\n",
    "\n",
    "In practice, if your model returns logits (instead of probabilities), you may want to apply\n",
    "`torch.nn.functional.cross_entropy` directly to the logits instead of converting them to log-probabilities and calling `torch.nn.functional.nll_loss`, for better numerical stability and efficiency."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2de93e09-f0c2-4bf0-9e8b-539f7d070f65",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([9, 12])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_outputs[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "78682c83-7d77-4f7f-98d0-eb271a55540b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 9])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "targets.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "0780aa39-4e34-45cb-9fde-180e3fac6e6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0567214488983154"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def pytorch_perplexity(prob, target):\n",
    "\n",
    "    log_prob = torch.log(prob)\n",
    "    loss = F.nll_loss(log_prob, target, reduction='mean')\n",
    "    perplexity = torch.exp(loss)\n",
    "    return perplexity.item()\n",
    "\n",
    "pytorch_perplexity(model_outputs[0], targets[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad1f198a-ca6c-4fc8-a1b6-f32916e12cbf",
   "metadata": {},
   "source": [
    "## Perplexity with log base 2 and natural log"
   ]
  },
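  {
   "cell_type": "markdown",
   "id": "f5c8e7a1-2b3d-4c9e-a6f0-8d1b2c3e4f56",
   "metadata": {},
   "source": [
    "The base of the logarithm does not affect the perplexity, as long as the same base is used for the exponentiation. A short derivation, where $b > 1$ denotes the base, $q_i$ the word probabilities, and $N$ the number of words (all three symbols are introduced here for illustration):\n",
    "\n",
    "$$b^{-\\frac{1}{N}\\sum_{i=1}^{N} \\log_b q_i} = \\left(\\prod_{i=1}^{N} b^{\\log_b q_i}\\right)^{-1/N} = \\left(\\prod_{i=1}^{N} q_i\\right)^{-1/N}$$\n",
    "\n",
    "The right-hand side no longer contains $b$. The two cells below check this numerically for $b = 2$ and $b = e$."
   ]
  },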
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "76c45995-f536-4380-8c90-43b10686a202",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Perplexity sentence 1: 1.0567214564189926\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "def calculate_perplexity_base2(probabilities):\n",
    "    log_probs = np.log2(probabilities)\n",
    "    avg_log_prob = np.mean(log_probs)\n",
    "    perplexity = 2 ** (-avg_log_prob)\n",
    "    return perplexity\n",
    "\n",
    "true_sentence = \"The quick brown fox jumps over the lazy dog\"\n",
    "sentence_1 = \"The fast black cat jumps over the lazy dog\"\n",
    "\n",
    "s1_word_proba = [0.99, 0.85, 0.89, 0.94, 0.99, 0.99, 0.99, 0.99, 0.90]\n",
    "perplex = calculate_perplexity_base2(s1_word_proba)\n",
    "print(\"Perplexity sentence 1:\", perplex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "0769ae64-0363-4987-af5e-167bdb9ce773",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Perplexity sentence 1: 1.0567214564189926\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "def calculate_perplexity_natural(probabilities):\n",
    "    log_probs = np.log(probabilities)\n",
    "    avg_log_prob = np.mean(log_probs)\n",
    "    perplexity = np.e ** (-avg_log_prob)\n",
    "    return perplexity\n",
    "\n",
    "true_sentence = \"The quick brown fox jumps over the lazy dog\"\n",
    "sentence_1 = \"The fast black cat jumps over the lazy dog\"\n",
    "\n",
    "s1_word_proba = [0.99, 0.85, 0.89, 0.94, 0.99, 0.99, 0.99, 0.99, 0.90]\n",
    "perplex = calculate_perplexity_natural(s1_word_proba)\n",
    "print(\"Perplexity sentence 1:\", perplex)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
